{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4998ae3c-74a9-4fee-812f-2fc2513c3915",
   "metadata": {},
   "source": [
    "# Scalable late interaction vectors in Elasticsearch: Token Pooling #\n",
    "\n",
    "In this notebook, we will look at how to scale search with late interaction models. We will focus on token pooling, a technique that reduces the number of vectors in a late interaction multi-vector representation by clustering tokens that carry similar information. It can be combined with the techniques discussed in the previous notebooks.\n",
    "\n",
    "This notebook builds on part 1, where we downloaded the images, created the ColPali vectors, and saved them to disk. Please run that notebook before trying the techniques in this one.\n",
    "\n",
    "Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
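    "def to_bit_vectors(embeddings: list) -> list:\n",
    "    # Binarize each vector (positive -> 1, otherwise 0) and pack 8 bits per byte as a hex string.\n",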
    "    return [\n",
    "        np.packbits(np.where(np.array(embedding) > 0, 1, 0))\n",
    "        .astype(np.int8)\n",
    "        .tobytes()\n",
    "        .hex()\n",
    "        for embedding in embeddings\n",
    "    ]"
   ]
  },
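  {
   "cell_type": "markdown",
   "id": "bit-vector-example-md",
   "metadata": {},
   "source": [
    "As a quick illustration of what `to_bit_vectors` produces, the sketch below binarizes a made-up 8-dimensional vector. The input values are assumptions chosen for readability; real ColPali vectors have more dimensions. Every 8 dimensions are packed into one byte, which becomes two hex characters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bit-vector-example-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical 8-dimensional embedding, used only for illustration.\n",
    "sample_embedding = [0.5, -0.1, 0.2, 0.0, -0.3, 0.9, 0.1, -0.7]\n",
    "\n",
    "# Positive values become 1, everything else (including 0.0) becomes 0.\n",
    "bits = np.where(np.array(sample_embedding) > 0, 1, 0)\n",
    "print(bits.tolist())  # [1, 0, 1, 0, 0, 1, 1, 0]\n",
    "\n",
    "# packbits stores the first dimension in the most significant bit: 0b10100110 == 0xa6.\n",
    "packed_hex = np.packbits(bits).astype(np.int8).tobytes().hex()\n",
    "print(packed_hex)  # a6"
   ]
  },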
  {
   "cell_type": "markdown",
   "id": "7e0887ca-f194-429b-9afa-208d047a75e4",
   "metadata": {},
   "source": [
    "We will use the `HierarchicalTokenPooler` from [colpali-engine](https://github.com/illuin-tech/colpali?tab=readme-ov-file#token-pooling) to reduce the number of vectors per document.  \n",
    "The authors recommend `pool_factor=3` for most cases, but you should always test how it impacts relevance on your own dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9871c9c5-c923-4deb-9f5b-aa6796ba0bbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from colpali_engine.compression.token_pooling import HierarchicalTokenPooler\n",
    "\n",
    "pooler = HierarchicalTokenPooler(\n",
    "    pool_factor=3\n",
    ")  # test on your data for a good pool_factor\n",
    "\n",
    "\n",
    "def pool_vectors(embedding: list) -> list:\n",
    "    tensor = torch.tensor(embedding).unsqueeze(0)\n",
    "    pooled = pooler.pool_embeddings(tensor)\n",
    "    return pooled.squeeze(0).tolist()"
   ]
  },
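  {
   "cell_type": "markdown",
   "id": "pooling-storage-estimate-md",
   "metadata": {},
   "source": [
    "To get a feel for the potential savings, here is a back-of-the-envelope sketch. All numbers are assumptions for illustration: roughly 1,030 vectors of 128 dimensions per page, and a pooled vector count of about one third of the original for `pool_factor=3`. The exact pooled count depends on the pooler and the data, so measure it on your own vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pooling-storage-estimate-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumed values for illustration only; check them against your own data.\n",
    "num_vectors = 1030  # vectors per page\n",
    "dims = 128  # dimensions per vector\n",
    "pool_factor = 3\n",
    "\n",
    "float_bytes = num_vectors * dims * 4  # float32 uses 4 bytes per dimension\n",
    "bit_bytes = num_vectors * dims // 8  # bit vectors use 1 bit per dimension\n",
    "\n",
    "# Rough estimate: pooling keeps about 1 / pool_factor of the vectors.\n",
    "pooled_vectors = max(num_vectors // pool_factor, 1)\n",
    "pooled_bit_bytes = pooled_vectors * dims // 8\n",
    "\n",
    "print(f\"float32, unpooled: {float_bytes:,} bytes per page\")\n",
    "print(f\"bit, unpooled:     {bit_bytes:,} bytes per page\")\n",
    "print(f\"bit, pooled:       {pooled_bit_bytes:,} bytes per page\")"
   ]
  },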
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Creating index: searchlabs-colpali-token-pooling\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "from dotenv import load_dotenv\n",
    "from elasticsearch import Elasticsearch\n",
    "\n",
    "load_dotenv(\"elastic.env\")\n",
    "\n",
    "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n",
    "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n",
    "INDEX_NAME = \"searchlabs-colpali-token-pooling\"\n",
    "\n",
    "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
    "\n",
    "mappings = {\n",
    "    \"mappings\": {\n",
    "        \"properties\": {\n",
    "            \"pooled_col_pali_vectors\": {\"type\": \"rank_vectors\", \"element_type\": \"bit\"}\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "if not es.indices.exists(index=INDEX_NAME):\n",
    "    print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
    "    es.indices.create(index=INDEX_NAME, body=mappings)\n",
    "else:\n",
    "    print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
    "\n",
    "\n",
    "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
    "    for attempt in range(1, retries + 1):\n",
    "        try:\n",
    "            return es_client.index(index=index, id=doc_id, document=document)\n",
    "        except Exception as e:\n",
    "            if attempt < retries:\n",
    "                wait_time = initial_backoff * (2 ** (attempt - 1))\n",
    "                print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n",
    "                time.sleep(wait_time)\n",
    "            else:\n",
    "                print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n",
    "                raise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "047c33b3344f49328bda552b123c168d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Indexing documents:   0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "\n",
    "\n",
    "def process_file(file_name, vectors):\n",
    "    if es.exists(index=INDEX_NAME, id=file_name):\n",
    "        return\n",
    "\n",
    "    pooled_vectors = pool_vectors(vectors)\n",
    "\n",
    "    bit_vectors = to_bit_vectors(pooled_vectors)\n",
    "\n",
    "    index_document(\n",
    "        es_client=es,\n",
    "        index=INDEX_NAME,\n",
    "        doc_id=file_name,\n",
    "        document={\"pooled_col_pali_vectors\": bit_vectors},\n",
    "    )\n",
    "\n",
    "\n",
    "with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n",
    "    file_to_multi_vectors = pickle.load(f)\n",
    "\n",
    "with ThreadPoolExecutor(max_workers=10) as executor:\n",
    "    list(\n",
    "        tqdm(\n",
    "            executor.map(\n",
    "                lambda item: process_file(*item), file_to_multi_vectors.items()\n",
    "            ),\n",
    "            total=len(file_to_multi_vectors),\n",
    "            desc=\"Indexing documents\",\n",
    "        )\n",
    "    )\n",
    "\n",
    "print(f\"Completed indexing {len(file_to_multi_vectors)} documents\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dfc3713-d649-46db-aa81-171d6d92668e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from colpali_engine.models import ColPali, ColPaliProcessor\n",
    "\n",
    "model_name = \"vidore/colpali-v1.3\"\n",
    "model = ColPali.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.float32,\n",
    "    device_map=\"mps\",  # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
    ").eval()\n",
    "\n",
    "col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
    "\n",
    "\n",
    "def create_col_pali_query_vectors(query: str) -> list:\n",
    "    queries = col_pali_processor.process_queries([query]).to(model.device)\n",
    "    with torch.no_grad():\n",
    "        return model(**queries).tolist()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "import os\n",
    "import json\n",
    "\n",
    "DOCUMENT_DIR = \"searchlabs-colpali\"\n",
    "\n",
    "query = \"What do companies use for recruiting?\"\n",
    "query_vector = create_col_pali_query_vectors(query)\n",
    "es_query = {\n",
    "    \"_source\": False,\n",
    "    \"query\": {\n",
    "        \"script_score\": {\n",
    "            \"query\": {\"match_all\": {}},\n",
    "            \"script\": {\n",
    "                \"source\": \"maxSimDotProduct(params.query_vector, 'pooled_col_pali_vectors')\",\n",
    "                \"params\": {\"query_vector\": query_vector},\n",
    "            },\n",
    "        }\n",
    "    },\n",
    "    \"size\": 5,\n",
    "}\n",
    "\n",
    "results = es.search(index=INDEX_NAME, body=es_query)\n",
    "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
    "\n",
    "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
    "for image_id in image_ids:\n",
    "    image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
    "    html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
    "html += \"</div>\"\n",
    "\n",
    "display(HTML(html))"
   ]
  },
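  {
   "cell_type": "markdown",
   "id": "max-sim-sketch-md",
   "metadata": {},
   "source": [
    "The `maxSimDotProduct` script scores a document by late interaction: for each query vector it takes the highest dot product against any document vector, then sums those maxima. The NumPy sketch below mirrors that idea for bit vectors. It is an illustration only, not Elasticsearch's implementation; `max_sim_dot_product` and the sample values are hypothetical, and the sketch assumes a set bit counts as 1.0 and a cleared bit as 0.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "max-sim-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def max_sim_dot_product(query_vectors: list, doc_hex_vectors: list) -> float:\n",
    "    # Hypothetical helper, for illustration only.\n",
    "    query = np.asarray(query_vectors, dtype=np.float32)  # (n_query, dims)\n",
    "    # Unpack each hex string back into one 0/1 value per dimension.\n",
    "    doc = np.array(\n",
    "        [\n",
    "            np.unpackbits(np.frombuffer(bytes.fromhex(h), dtype=np.uint8))\n",
    "            for h in doc_hex_vectors\n",
    "        ],\n",
    "        dtype=np.float32,\n",
    "    )  # (n_doc, dims)\n",
    "    similarities = query @ doc.T  # (n_query, n_doc)\n",
    "    # Best matching document vector per query vector, summed over the query.\n",
    "    return float(similarities.max(axis=1).sum())\n",
    "\n",
    "\n",
    "# Made-up inputs: one 8-dimensional query vector and two packed document vectors.\n",
    "sample_query = [[1.0, -1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 2.0]]\n",
    "sample_doc = [\"a6\", \"01\"]\n",
    "print(max_sim_dot_product(sample_query, sample_doc))  # 2.0"
   ]
  },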
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32fd9ee4-d7c6-4954-a766-7b06735290ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
    "print(\"Shutting down the kernel to free memory...\")\n",
    "import os\n",
    "\n",
    "os._exit(0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dependecy-test-colpali-blog",
   "language": "python",
   "name": "dependecy-test-colpali-blog"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
